
[V1][Kernel] Flashinfer HND KV cache layout #19280

Open
wants to merge 7 commits into base: main

Conversation

NickLucche
Contributor

@NickLucche NickLucche commented Jun 6, 2025

Follow-up PR to #18775, again porting over functionality from V0 (ref #16605).

This PR enables the use of FlashInfer with an HND KV cache layout in V1.
Most immediately, it is a prerequisite for enabling heterogeneous TP support in disaggregated prefill-decode setups, since the HND layout is better suited for KV cache transfers.

Test with:

FLASHINFER_KV_CACHE_LAYOUT=HND VLLM_ATTENTION_BACKEND=FLASHINFER vllm serve Qwen/Qwen3-0.6B  

A simple benchmark:

# FLASHINFER_KV_CACHE_LAYOUT=HND  VLLM_ATTENTION_BACKEND=FLASHINFER vllm serve Qwen/Qwen3-14B --disable-log-requests

NHD

============ Serving Benchmark Result ============
Successful requests:                     982       
Benchmark duration (s):                  94.63     
Total input tokens:                      1745235   
Total generated tokens:                  125696    
Request throughput (req/s):              10.38     
Output token throughput (tok/s):         1328.26   
Total Token throughput (tok/s):          19770.61  
---------------Time to First Token----------------
Mean TTFT (ms):                          45294.13  
Median TTFT (ms):                        46790.26  
P99 TTFT (ms):                           91790.54  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          106.51    
Median TPOT (ms):                        113.65    
P99 TPOT (ms):                           131.63    
---------------Inter-token Latency----------------
Mean ITL (ms):                           106.51    
Median ITL (ms):                         30.62     
P99 ITL (ms):                            378.72    
==================================================

HND (this PR)

============ Serving Benchmark Result ============
Successful requests:                     982       
Benchmark duration (s):                  95.03     
Total input tokens:                      1745219   
Total generated tokens:                  125696    
Request throughput (req/s):              10.33     
Output token throughput (tok/s):         1322.67   
Total Token throughput (tok/s):          19687.22  
---------------Time to First Token----------------
Mean TTFT (ms):                          45028.82  
Median TTFT (ms):                        46518.14  
P99 TTFT (ms):                           92111.68  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          106.67    
Median TPOT (ms):                        111.74    
P99 TPOT (ms):                           132.01    
---------------Inter-token Latency----------------
Mean ITL (ms):                           106.67    
Median ITL (ms):                         30.41     
P99 ITL (ms):                            378.99    
==================================================

cc @mgoin

Signed-off-by: nicklucche <[email protected]>
Signed-off-by: nicklucche <[email protected]>
Signed-off-by: NickLucche <[email protected]>
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Hello @NickLucche, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

Summary of Changes

Hello team, gemini-code-assist here with a summary of this pull request. This PR, authored by @NickLucche, focuses on enabling the HND (num_heads, block_size, head_dim) KV cache layout specifically for the FlashInfer attention backend in V1 of vLLM. This is a follow-up to previous work and is intended to be a prerequisite for features like heterogeneous Tensor Parallelism (TP) with a disaggregated prefill-decode setup, as the HND layout is noted to be better optimized for data transfers in such scenarios. The changes involve modifying the FlashInfer backend to dynamically determine the desired KV cache layout (either via an environment variable or the vLLM config) and applying the necessary tensor permutation before interacting with the FlashInfer kernels.
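To make the layout difference concrete, here is a minimal PyTorch sketch of the permutation described above. The five-dimensional paged cache shape is an assumption for illustration; the stride order (0, 1, 3, 2, 4) matches the review excerpt further below.

import torch

# Illustrative paged KV cache in NHD order:
# (num_blocks, 2 [K/V], block_size, num_heads, head_dim)
num_blocks, block_size, num_heads, head_dim = 4, 16, 8, 64
kv_cache_nhd = torch.zeros(num_blocks, 2, block_size, num_heads, head_dim)

# Applying the stride order (0, 1, 3, 2, 4) swaps the block_size and
# num_heads dimensions, giving the HND view:
# (num_blocks, 2 [K/V], num_heads, block_size, head_dim)
kv_cache_hnd = kv_cache_nhd.permute(0, 1, 3, 2, 4)
print(kv_cache_hnd.shape)  # torch.Size([4, 2, 8, 16, 64])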

Highlights

  • FlashInfer HND KV Cache Layout: Introduces support for the HND (num_heads, block_size, head_dim) KV cache layout when using the FlashInfer attention backend in V1.
  • Configurable KV Cache Layout: Allows the KV cache layout for FlashInfer to be specified via the FLASHINFER_KV_CACHE_LAYOUT environment variable, overriding the default or configured layout.
  • Tensor Permutation for Layout: Implements logic within the FlashInfer backend to permute the KV cache tensor to the desired HND layout before passing it to the FlashInfer kernels.
  • Enabling Disaggregated PD: This change is highlighted as a necessary step towards enabling heterogeneous TP and optimizing data transfers for disaggregated prefill-decode setups.

Changelog

  • vllm/distributed/kv_transfer/kv_connector/utils.py
    • Removed a redundant blank line (L6).
    • Added a note explaining that HND layout is used for faster transfer in disaggregated PD with NIXL (L96-97).
    • Improved the check in get_kv_connector_cache_layout to also verify if kv_config is not None before defaulting to NHD (L100).
  • vllm/v1/attention/backends/flashinfer.py
    • Imported functools and os (L6-7).
    • Imported get_kv_connector_cache_layout from vllm.distributed.kv_transfer.kv_connector.utils (L21-22).
    • Added a module-level variable FLASHINFER_KV_CACHE_LAYOUT to read the environment variable (L35-36).
    • Added a cached function get_flashinfer_kv_cache_layout to determine the layout, prioritizing the environment variable (L41-51).
    • Added a static method get_kv_cache_stride_order to FlashInferBackend to return the permutation tuple based on the determined layout (L87-98).
    • Updated the initialization of BatchPrefillWithPagedKVCacheWrapper to use the layout from get_flashinfer_kv_cache_layout() instead of hardcoded "NHD" (L324).
    • Updated the initialization of BatchDecodeWithPagedKVCacheWrapper to use the layout from get_flashinfer_kv_cache_layout() instead of hardcoded "NHD" (L337).
    • Updated the initialization of MultiLevelCascadeAttentionWrapper to use the layout from get_flashinfer_kv_cache_layout() instead of hardcoded "NHD" (L345).
    • Called FlashInferBackend.get_kv_cache_stride_order() in the forward method to get the required permutation (L643).
    • Applied .permute(*stride_order) to the kv_cache tensor before passing it to prefill_wrapper.run (L658).
    • Applied .permute(*stride_order) to the kv_cache tensor before passing it to decode_wrapper.run (L674).
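Taken together, the changelog items amount to roughly the following mechanism. This is a sketch reconstructed from the descriptions above, not the exact code in the PR; the empty-string default for the environment variable and the exact signatures are assumptions.

import functools
import os

from vllm.distributed.kv_transfer.kv_connector.utils import (
    get_kv_connector_cache_layout)

# Optional user override, read once at module import (default assumed empty).
FLASHINFER_KV_CACHE_LAYOUT: str = os.getenv("FLASHINFER_KV_CACHE_LAYOUT", "")


@functools.lru_cache
def get_flashinfer_kv_cache_layout() -> str:
    # The environment variable takes priority over the connector/config layout,
    # which itself defaults to NHD.
    if FLASHINFER_KV_CACHE_LAYOUT:
        return FLASHINFER_KV_CACHE_LAYOUT
    return get_kv_connector_cache_layout()


def get_kv_cache_stride_order() -> tuple[int, ...]:
    # Map the chosen layout onto a permutation of the 5-D paged KV cache.
    cache_layout = get_flashinfer_kv_cache_layout()
    if cache_layout == "NHD":
        return (0, 1, 2, 3, 4)
    if cache_layout == "HND":
        return (0, 1, 3, 2, 4)
    raise ValueError(f"Unknown cache layout format {cache_layout}.")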

@mergify mergify bot added the v1 label Jun 6, 2025

github-actions bot commented Jun 6, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of tests to catch errors quickly. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request successfully enables the HND KV cache layout for FlashInfer in V1, enhancing flexibility and potentially performance for disaggregated prefill-decode setups. The changes are well-structured, introducing a clear mechanism for determining the cache layout (via environment variable or vLLM config) and applying it consistently.

Key improvements include:

  • Robust handling of KV cache layout determination with user overrides.
  • Centralized logic for stride order based on the selected layout.
  • Consistent application of the layout across prefill, decode, and cascade attention wrappers.

A minor issue with error message formatting was found, which should be addressed. Overall, the PR is a good step towards optimizing KV cache transfers.

Summary of Findings

  • Incorrect ValueError Formatting: In vllm/v1/attention/backends/flashinfer.py, the ValueError raised for an unknown cache layout uses incorrect string formatting. The %s placeholder is never interpolated, so the raised exception would carry the literal format string plus the layout value as a separate argument instead of a readable message. It should be updated to use f-string formatting.
  • Enhanced Configuration Robustness: In vllm/distributed/kv_transfer/kv_connector/utils.py, an additional null check for kv_config was added. This improves the robustness of get_kv_connector_cache_layout by preventing potential AttributeError if kv_config is None.
  • Flexible KV Cache Layout Configuration: The PR introduces the FLASHINFER_KV_CACHE_LAYOUT environment variable, allowing users to override the KV cache layout for FlashInfer. This falls back to the vLLM configuration if the environment variable is not set, providing good flexibility.
  • Efficient Layout Determination: The use of functools.lru_cache on the get_flashinfer_kv_cache_layout function is a good optimization, preventing redundant computations of the cache layout within a single forward pass.

Merge Readiness

The pull request is generally in good shape and implements the intended functionality effectively. There is one medium-severity issue regarding ValueError formatting that should be addressed before merging. Once this is fixed, the PR should be ready for merge. I am unable to approve the pull request myself; please ensure it is reviewed and approved by other maintainers.

    elif cache_layout == "HND":
        stride_order = (0, 1, 3, 2, 4)
    else:
        raise ValueError("Unknown cache layout format %s.", cache_layout)
Contributor


medium

The string formatting for the ValueError message appears to be incorrect. It uses %s-style formatting, but cache_layout is passed as a second constructor argument rather than interpolated into the message, so the raised exception would contain the literal format string alongside the layout value instead of a readable message.

Could you update this to use an f-string for clarity and correctness?

Suggested change
raise ValueError("Unknown cache layout format %s.", cache_layout)
raise ValueError(f"Unknown cache layout format {cache_layout}.")

Signed-off-by: NickLucche <[email protected]>
@robertgshaw2-redhat
Collaborator

Did you do any performance tests to see how this impacts E2E performance for non-PD setups?

I think at least a throughput benchmark is warranted.

@@ -28,10 +32,25 @@
from vllm.v1.worker.gpu_model_runner import GPUModelRunner

FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
FLASHINFER_KV_CACHE_LAYOUT: str = os.getenv("FLASHINFER_KV_CACHE_LAYOUT",
Collaborator


This should be defined like all other env variables in vLLM: in envs.py and named VLLM_FLASHINFER_KV_CACHE_LAYOUT.

I also think we should have the env variable be VLLM_KV_CACHE_LAYOUT rather than having a specific one for each attention backend type.
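For illustration, a minimal sketch of what a backend-agnostic variable registered alongside vLLM's other env vars might look like. The VLLM_KV_CACHE_LAYOUT name follows the suggestion above; the getter-dict pattern and the "NHD" default are assumptions, not the PR's actual code.

import os

# Hypothetical entry following the getter pattern used for vLLM env vars.
# The "NHD" default is an assumption.
environment_variables = {
    # Backend-agnostic KV cache layout override ("NHD" or "HND").
    "VLLM_KV_CACHE_LAYOUT": lambda: os.getenv("VLLM_KV_CACHE_LAYOUT", "NHD"),
}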

@NickLucche
Contributor Author

I compared a few models (Llama 3 / Qwen3) and didn't notice a huge impact, to be honest. I've added a simple H100 benchmark in the description.

I should also probably tag @wenscarl to elaborate on HND for FlashInfer.


mergify bot commented Jun 9, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @NickLucche.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jun 9, 2025